from __future__ import annotations
import asyncio
from typing import Any, Dict
from sqlalchemy import create_engine, text, Column, String, Text
from sqlalchemy.orm import declarative_base, sessionmaker, Session
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from pgvector.sqlalchemy import Vector
from langbot.pkg.vector.vdb import VectorDatabase
from langbot.pkg.core import app
import uuid

Base = declarative_base()


class PgVectorEntry(Base):
    """SQLAlchemy model for pgvector entries"""
    __tablename__ = 'langbot_vectors'

    id = Column(String, primary_key=True)
    collection = Column(String, index=True, nullable=False)
    embedding = Column(Vector(1536))  # Default dimension, will be created dynamically
    text = Column(Text)
    file_id = Column(String, index=True)
    chunk_uuid = Column(String)


class PgVectorDatabase(VectorDatabase):
    """PostgreSQL with pgvector extension database implementation"""

    def __init__(
        self,
        ap: app.Application,
        connection_string: str = None,
        host: str = "localhost",
        port: int = 5432,
        database: str = "langbot",
        user: str = "postgres",
        password: str = "postgres"
    ):
        """Initialize pgvector database

        Args:
            ap: Application instance
            connection_string: Full PostgreSQL connection string (overrides other params)
            host: PostgreSQL host
            port: PostgreSQL port
            database: Database name
            user: Database user
            password: Database password
        """
        self.ap = ap

        # Build connection string if not provided
        if connection_string:
            self.connection_string = connection_string
        else:
            self.connection_string = (
                f"postgresql+psycopg://{user}:{password}@{host}:{port}/{database}"
            )

        self.async_connection_string = self.connection_string.replace(
            "postgresql://", "postgresql+asyncpg://"
        ).replace(
            "postgresql+psycopg://", "postgresql+asyncpg://"
        )

        self.engine = None
        self.async_engine = None
        self.SessionLocal = None
        self.AsyncSessionLocal = None
        self._collections = set()
        self._initialize_db()

    def _initialize_db(self):
        """Initialize database connection and create tables"""
        try:
            # Create async engine for async operations
            self.async_engine = create_async_engine(
                self.async_connection_string,
                echo=False,
                pool_pre_ping=True
            )
            self.AsyncSessionLocal = async_sessionmaker(
                self.async_engine,
                class_=AsyncSession,
                expire_on_commit=False
            )

            # Create sync engine for table creation
            sync_connection_string = self.connection_string.replace(
                "postgresql+asyncpg://", "postgresql+psycopg://"
            )
            self.engine = create_engine(sync_connection_string, echo=False)

            # Create pgvector extension and tables
            with self.engine.connect() as conn:
                # Enable pgvector extension
                conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
                conn.commit()

            # Create tables
            Base.metadata.create_all(self.engine)

            self.ap.logger.info(f"Connected to PostgreSQL with pgvector")
        except Exception as e:
            self.ap.logger.error(f"Failed to connect to PostgreSQL: {e}")
            raise

    async def get_or_create_collection(self, collection: str):
        """Get or create a collection (logical grouping in pgvector)

        Args:
            collection: Collection name (knowledge base UUID)
        """
        # In pgvector, collections are logical - we just track them
        if collection not in self._collections:
            self._collections.add(collection)
            self.ap.logger.info(f"Registered pgvector collection '{collection}'")
        return collection

    async def add_embeddings(
        self,
        collection: str,
        ids: list[str],
        embeddings_list: list[list[float]],
        metadatas: list[dict[str, Any]],
    ) -> None:
        """Add vector embeddings to pgvector

        Args:
            collection: Collection name
            ids: List of unique IDs for each vector
            embeddings_list: List of embedding vectors
            metadatas: List of metadata dictionaries
        """
        await self.get_or_create_collection(collection)

        async with self.AsyncSessionLocal() as session:
            try:
                for i, vector_id in enumerate(ids):
                    metadata = metadatas[i] if i < len(metadatas) else {}

                    entry = PgVectorEntry(
                        id=vector_id,
                        collection=collection,
                        embedding=embeddings_list[i],
                        text=metadata.get("text", ""),
                        file_id=metadata.get("file_id", ""),
                        chunk_uuid=metadata.get("uuid", "")
                    )
                    session.add(entry)

                await session.commit()
                self.ap.logger.info(
                    f"Added {len(ids)} embeddings to pgvector collection '{collection}'"
                )
            except Exception as e:
                await session.rollback()
                self.ap.logger.error(f"Error adding embeddings to pgvector: {e}")
                raise

    async def search(
        self, collection: str, query_embedding: list[float], k: int = 5
    ) -> Dict[str, Any]:
        """Search for similar vectors using cosine distance

        Args:
            collection: Collection name
            query_embedding: Query vector
            k: Number of top results to return

        Returns:
            Dictionary with search results in Chroma-compatible format
        """
        await self.get_or_create_collection(collection)

        async with self.AsyncSessionLocal() as session:
            try:
                # Use cosine distance for similarity search
                from sqlalchemy import select, func

                # Query for similar vectors
                stmt = (
                    select(
                        PgVectorEntry.id,
                        PgVectorEntry.text,
                        PgVectorEntry.file_id,
                        PgVectorEntry.chunk_uuid,
                        PgVectorEntry.embedding.cosine_distance(query_embedding).label('distance')
                    )
                    .filter(PgVectorEntry.collection == collection)
                    .order_by(PgVectorEntry.embedding.cosine_distance(query_embedding))
                    .limit(k)
                )

                result = await session.execute(stmt)
                rows = result.fetchall()

                # Convert to Chroma-compatible format
                ids = []
                distances = []
                metadatas = []

                for row in rows:
                    ids.append(row.id)
                    distances.append(float(row.distance))
                    metadatas.append({
                        "text": row.text or "",
                        "file_id": row.file_id or "",
                        "uuid": row.chunk_uuid or ""
                    })

                result_dict = {
                    "ids": [ids],
                    "distances": [distances],
                    "metadatas": [metadatas]
                }

                self.ap.logger.info(
                    f"pgvector search in '{collection}' returned {len(ids)} results"
                )
                return result_dict

            except Exception as e:
                self.ap.logger.error(f"Error searching pgvector: {e}")
                raise

    async def delete_by_file_id(self, collection: str, file_id: str) -> None:
        """Delete vectors by file_id

        Args:
            collection: Collection name
            file_id: File ID to filter deletion
        """
        await self.get_or_create_collection(collection)

        async with self.AsyncSessionLocal() as session:
            try:
                from sqlalchemy import delete

                stmt = delete(PgVectorEntry).where(
                    PgVectorEntry.collection == collection,
                    PgVectorEntry.file_id == file_id
                )
                await session.execute(stmt)
                await session.commit()

                self.ap.logger.info(
                    f"Deleted embeddings from pgvector collection '{collection}' with file_id: {file_id}"
                )
            except Exception as e:
                await session.rollback()
                self.ap.logger.error(f"Error deleting from pgvector: {e}")
                raise

    async def delete_collection(self, collection: str):
        """Delete all vectors in a collection

        Args:
            collection: Collection name to delete
        """
        if collection in self._collections:
            self._collections.remove(collection)

        async with self.AsyncSessionLocal() as session:
            try:
                from sqlalchemy import delete

                stmt = delete(PgVectorEntry).where(
                    PgVectorEntry.collection == collection
                )
                await session.execute(stmt)
                await session.commit()

                self.ap.logger.info(f"Deleted pgvector collection '{collection}'")
            except Exception as e:
                await session.rollback()
                self.ap.logger.error(f"Error deleting pgvector collection: {e}")
                raise

    async def close(self):
        """Close database connections"""
        if self.async_engine:
            await self.async_engine.dispose()
        if self.engine:
            self.engine.dispose()
